package org.example.graphx.local

import org.apache.log4j.{Logger}
import org.apache.spark.graphx.{Graph, VertexId}
import org.apache.spark.graphx.util.GraphGenerators
import org.apache.spark.sql.SparkSession

object SSSPExample {

  import scala.util.Try

  /** Number of vertices generated when no command-line argument overrides it. */
  private val DefaultNumVertices: Int = 100

  /** Source vertex used when no command-line argument overrides it. */
  private val DefaultSourceId: VertexId = 42L

  /**
   * Returns positional argument `index` parsed with `parse`, or `default` when it is absent.
   *
   * @throws IllegalArgumentException if the argument is present but `parse` rejects it
   */
  private def optionalArg[A](args: Array[String], index: Int, default: A)(parse: String => A): A =
    args.drop(index).headOption match {
      case None => default
      case Some(raw) =>
        Try(parse(raw.trim)).getOrElse(
          throw new IllegalArgumentException(s"argument $index must be a number, got '$raw'"))
    }

  /**
   * Runs single-source shortest paths (SSSP) with the GraphX Pregel API on a randomly generated
   * log-normal graph in a local SparkSession, then prints the computed distances.
   *
   * @param args optional positional arguments:
   *             `args(0)` = number of vertices to generate (default 100),
   *             `args(1)` = id of the source vertex (default 42).
   * @throws IllegalArgumentException if an argument is malformed, the vertex count is not
   *                                  positive, or the source id is outside `[0, numVertices)`
   */
  def main(args: Array[String]): Unit = {

    val logger = Logger.getLogger(SSSPExample.getClass.getName)

    val numVertices: Int = optionalArg(args, 0, DefaultNumVertices)(_.toInt)
    val sourceId: VertexId = optionalArg(args, 1, DefaultSourceId)(_.toLong)
    require(numVertices > 0, s"numVertices must be positive, got $numVertices")
    // logNormalGraph numbers its vertices 0 until numVertices; any other source id would leave
    // every distance at +infinity without any error.
    require(sourceId >= 0 && sourceId < numVertices,
      s"sourceId must be in [0, $numVertices), got $sourceId")

    // Create a local SparkSession
    val spark = SparkSession
      .builder
      .master("local")
      .appName("graphx-SSSPAlgo")
      .getOrCreate()

    // Release Spark resources even if graph construction or the Pregel run fails.
    try {
      val sc = spark.sparkContext

      // Generate a random graph and widen its integer edge attributes to Double distances.
      val graph: Graph[Long, Double] =
        GraphGenerators.logNormalGraph(sc, numVertices = numVertices).mapEdges(e => e.attr.toDouble)
      logger.warn("原始图为：")
      println(graph.edges.take(10).mkString("\n"))

      // Distance 0 at the source, +infinity everywhere else until a path is found.
      val initialGraph =
        graph.mapVertices((id, _) => if (id == sourceId) 0.0 else Double.PositiveInfinity)

      val sssp = initialGraph.pregel(Double.PositiveInfinity)(
        // Vertex program: keep the smaller of the current distance and the incoming candidate.
        (_, dist, newDist) => math.min(dist, newDist),
        // Send message: relax the edge only when it shortens the destination's distance;
        // sending nothing otherwise lets Pregel stop once no edge can be relaxed.
        triplet => {
          val candidate = triplet.srcAttr + triplet.attr
          if (candidate < triplet.dstAttr) Iterator((triplet.dstId, candidate))
          else Iterator.empty
        },
        // Merge message: of several candidate distances, only the shortest matters.
        (a, b) => math.min(a, b)
      )

      // Braces are required: the following CJK characters are identifier letters in Scala.
      logger.warn(s"${sourceId}到其余各点的最短路径信息为：")
      println(sssp.vertices.collect().mkString("\n"))
    } finally {
      spark.stop()
    }
  }

}
